
from utils import DATA_DIR
import os
import json
import re
from LLM_rank.rank_candidate import rank2_allchain_candidate_list , sum_based_rankn
from LLM.openai_0613 import chatgpt_0613
from copy import deepcopy
from tqdm import tqdm
import threading
import queue
import time
import random

# Directory scanned by main(); each entry is expected to be named
# "<query_id>_<model>_<method>.json".
answer_dir = os.path.join(DATA_DIR,"answer_G3_singleanswer_12")

# Directory where deal_task() writes one "<query_id>.json" ranking result per query.
rank_output_dir = os.path.join(DATA_DIR,"rank_output")

# Methods that contribute up to two candidates per query in main():
# "<method>_best" and "<method>_random".
two_answer_methods = [
    "DFS_woFilter_w2",
    "ETS_annealing_sqrt_newPrompt_s20_f1_t173.72_p0.5_c4_m6_rn3_rg4",
]

# Methods that contribute exactly one candidate per query in main(),
# labelled "<method>_random".
one_answer_methods = [
    "CoT@3",
    "BFS_w3_e2",
    "Reflexion@3",
]



# Module-level LLM client created at import time.
# NOTE(review): deal_task() builds its own local client with the same model
# name, so nothing visible in this file reads this instance - confirm whether
# it is still needed.
LLM_interface = chatgpt_0613(model="gpt-3.5-turbo-16k-0613")







def _candidate_entry(cont, method, suffix, model):
    """Build one candidate record; a missing chain (``None``) gets confidence 0.0."""
    confidence = 0.0 if cont is None else cont[-1]["Elo"]
    return {
        "cont": cont,
        "method": method + suffix,
        "model": model,
        "confidence": confidence,
    }


def main():
    """Scan ``answer_dir`` and build one ranking task per query.

    Every entry of ``answer_dir`` must be named
    ``<query_id>_<model>_<method>.json``. Files are grouped by query id, and
    queries with fewer than two answer files are skipped. For the remaining
    queries, each method listed in ``two_answer_methods`` or
    ``one_answer_methods`` contributes candidates taken from the file's
    ``"compare_candidates"`` list (each chain carries its score in
    ``chain[-1]["Elo"]``):

    * two-answer methods: the highest-Elo chain as ``<method>_best`` plus one
      chain drawn at random from the others as ``<method>_random`` (``None``
      with confidence 0.0 when only one chain exists; nothing at all when the
      list is empty);
    * one-answer methods: the first chain as ``<method>_random`` (``None``
      with confidence 0.0 when the list is empty).

    Returns:
        A list of ``(candidates, query_id, query, functions, method2file)``
        tuples, where ``method2file`` maps each method found on disk to
        ``[file_name, model]``. ``query`` and ``functions`` come from the last
        recognised answer file read for that query and stay ``None`` when no
        method of the query is recognised.

    Raises:
        ValueError: if a directory entry does not follow the naming scheme.
    """
    name_pattern = re.compile(r"(\d+)_([^_]+)_(.+)\.json")

    # Group the answer files by query id: query_id -> {method: [file, model]}.
    query2answer = {}
    for file_name in tqdm(os.listdir(answer_dir)):
        matched = name_pattern.match(file_name)
        if matched is None:
            # An explicit error (instead of ``assert``) survives ``python -O``
            # and tells the user which file is the offending one.
            raise ValueError(
                f"unexpected file name in {answer_dir!r}: {file_name!r} "
                "(expected '<query_id>_<model>_<method>.json')"
            )
        query_id = int(matched.group(1))
        model = matched.group(2)
        method = matched.group(3)
        query2answer.setdefault(query_id, {})[method] = [file_name, model]

    task_list = []
    for query_id, method2file in tqdm(query2answer.items()):
        if len(method2file) < 2:
            continue

        query = None
        functions = None
        candidates = []

        for method, (file_name, model) in method2file.items():
            is_two_answer = method in two_answer_methods
            if not is_two_answer and method not in one_answer_methods:
                continue

            with open(os.path.join(answer_dir, file_name), "r") as reader:
                json_data = json.load(reader)
            query = json_data["answer_generation"]["query"]
            functions = json_data["answer_generation"]["function"]
            chains = json_data["compare_candidates"]

            if is_two_answer:
                if len(chains) > 1:
                    # sorted() is stable, also with reverse=True, so chains
                    # with equal Elo (including the all-tied case) keep their
                    # original file order.
                    ranked = sorted(
                        chains, key=lambda chain: chain[-1]["Elo"], reverse=True
                    )
                    candidates.append(
                        _candidate_entry(ranked[0], method, "_best", model)
                    )
                    # Draw the second candidate from the non-best chains.
                    # Shuffling the slice ``ranked[1:]`` would only shuffle a
                    # temporary copy and always leave index 1 selected.
                    candidates.append(
                        _candidate_entry(
                            random.choice(ranked[1:]), method, "_random", model
                        )
                    )
                elif len(chains) == 1:
                    candidates.append(
                        _candidate_entry(chains[0], method, "_best", model)
                    )
                    candidates.append(
                        _candidate_entry(None, method, "_random", model)
                    )
                # A two-answer method without any chain contributes nothing.
            else:
                only_chain = chains[0] if len(chains) > 0 else None
                candidates.append(
                    _candidate_entry(only_chain, method, "_random", model)
                )

        task_list.append((candidates, query_id, query, functions, method2file))
    return task_list

def deal_task(candidates,query_id, query,functions,method2_file_name,process_id):
    """Rank the candidates of one query with the LLM and store the result.

    Args:
        candidates: candidate records; each one is read here through its
            ``"method"`` and ``"confidence"`` keys and handed to
            ``sum_based_rankn`` unchanged.
        query_id: identifier of the query; also the stem of the output file.
        query: the query text, passed to the ranker as ``input_description``.
        functions: function descriptions; JSON-dumped into the task
            description and passed to the ranker as ``functions``.
        method2_file_name: not used by this function; kept so a task tuple can
            be unpacked positionally by the caller.
        process_id: worker identifier, used in the log line and passed to the
            ranker.

    Side effects:
        Writes ``<rank_output_dir>/<query_id>.json``.
    """
    task_description = '''Do the following tasks with function calls, you have access of the follow functions:\n''' + json.dumps(functions,indent=2)

    print(f"[process{process_id}] now doing query {query_id}")

    # Each call builds its own client instead of sharing one across workers.
    LLM_interface = chatgpt_0613(model="gpt-3.5-turbo-16k-0613")

    LLM_rank_args = {
        "functions": functions,
        "process_id": process_id,
        "task_description": task_description,
        "input_description": query,
        "rank_func": rank2_allchain_candidate_list,
    }

    # The second and third return values are not used here.
    scores, _, _, rank_details = sum_based_rankn(LLM_interface,LLM_rank_args, candidates)

    output_json_data = {
        "candidates": candidates,
        "rank_detail":  rank_details,
        "methods": [cont["method"] for cont in candidates],
        "before_Elo": [cont["confidence"] for cont in candidates],
        "scores": scores,
    }

    # Make sure the target directory exists so the ranking result is not lost
    # to a missing folder after the LLM calls have already been made.
    os.makedirs(rank_output_dir, exist_ok=True)
    out_path = os.path.join(rank_output_dir, f"{query_id}.json")

    # Write to a temporary file and move it into place: the driver treats an
    # existing "<query_id>.json" as "already done", so an interrupted write
    # must never leave a partial file under the final name.
    tmp_path = f"{out_path}.tmp"
    with open(tmp_path, "w") as writer:
        json.dump(output_json_data, writer, indent=2)
    os.replace(tmp_path, out_path)
    
class Consumer(threading.Thread):
    """Worker thread that pulls tasks from the module-level queue ``q``.

    Each task is a tuple that is unpacked into ``deal_task`` together with this
    worker's ``process_id``. The thread ends once the queue has no task left.
    """

    def __init__(self, process_id,starting_time):
        """Store the worker id and the start timestamp used for elapsed-time logs."""
        super().__init__()
        self.process_id = process_id
        self.starting_time=starting_time

    def run(self):
        """Process tasks until the queue is drained, logging progress."""
        global q
        while True:
            # Fetch without blocking. Checking ``q.empty()`` and then calling
            # the blocking ``q.get()`` is a race: another worker can take the
            # last task in between, leaving this thread blocked forever and
            # the process unable to exit.
            try:
                task = q.get_nowait()
            except queue.Empty:
                break
            print(f"process[{self.process_id}] get task, now task_queue len={q.qsize()}, time_usage={(time.time() - self.starting_time)/60:.2f}min")
            deal_task(*task,process_id=self.process_id)
        print(f"process[{self.process_id}] finish, time_usage={(time.time() - self.starting_time)/60:.2f}min")

if __name__ == "__main__":
    # Build every ranking task first, then hand the unfinished ones to workers.
    task_list = main()
    process_num = 60

    q = queue.Queue(10000000)
    starting_time = time.time()

    # Resume support: a query whose "<query_id>.json" already exists in the
    # output directory is considered done and is not queued again.
    pending_tasks = (
        task
        for task in task_list
        if not os.path.exists(os.path.join(rank_output_dir, f"{task[1]}.json"))
    )
    for task in pending_tasks:
        q.put(task)

    # Start the workers; they are non-daemon threads, so the interpreter stays
    # alive until each of them has returned from run().
    for worker_index in range(process_num):
        Consumer(process_id=worker_index, starting_time=starting_time).start()